{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "517af4d9",
   "metadata": {},
   "source": [
    "# Agentic RAG\n",
    "\n",
    "에이전트(Agent) 는 검색 도구를 사용할지 여부를 결정해야 할 때 유용합니다. 에이전트와 관련된 내용은 [Agent](https://wikidocs.net/233782) 페이지를 참고하세요.\n",
    "\n",
    "검색 에이전트를 구현하기 위해서는 `LLM`에 검색 도구에 대한 접근 권한을 부여하기만 하면 됩니다.\n",
    "\n",
    "이를 [LangGraph](https://langchain-ai.github.io/langgraph/)에 통합할 수 있습니다.\n",
    "\n",
    "![langgraph-agentic-rag](assets/langgraph-agentic-rag.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "532dc8f5",
   "metadata": {},
   "source": [
    "## 환경 설정"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4def9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1966105",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install -qU langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH17-LangGraph-Structures\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ee354d2",
   "metadata": {},
   "source": [
    "## 기본 PDF 기반 Retrieval Chain 생성\n",
    "\n",
    "여기서는 PDF 문서를 기반으로 Retrieval Chain 을 생성합니다. 가장 단순한 구조의 Retrieval Chain 입니다.\n",
    "\n",
    "단, LangGraph 에서는 Retirever 와 Chain 을 따로 생성합니다. 그래야 각 노드별로 세부 처리를 할 수 있습니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- 이전 튜토리얼에서 다룬 내용이므로, 자세한 설명은 생략합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72d62a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rag.pdf import PDFRetrievalChain\n",
    "\n",
    "# PDF 문서를 로드합니다.\n",
    "pdf = PDFRetrievalChain([\"data/SPRI_AI_Brief_2023년12월호_F.pdf\"]).create_chain()\n",
    "\n",
    "# retriever와 chain을 생성합니다.\n",
    "pdf_retriever = pdf.retriever\n",
    "pdf_chain = pdf.chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8575f296",
   "metadata": {},
   "source": [
    "그 다음 `retriever_tool` 도구를 생성합니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "`document_prompt` 는 검색된 문서를 표현하는 프롬프트입니다.\n",
    "\n",
    "**사용가능한 키** \n",
    "\n",
    "- `page_content`\n",
    "- `metadata` 의 키: (예시) `source`, `page`\n",
    "\n",
    "**사용예시**\n",
    "\n",
    "`\"<document><context>{page_content}</context><metadata><source>{source}</source><page>{page}</page></metadata></document>\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c8d6f279",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools.retriever import create_retriever_tool\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "# PDF 문서를 기반으로 검색 도구 생성\n",
    "retriever_tool = create_retriever_tool(\n",
    "    pdf_retriever,\n",
    "    \"pdf_retriever\",\n",
    "    \"Search and return information about SPRI AI Brief PDF file. It contains useful information on recent AI trends. The document is published on Dec 2023.\",\n",
    "    document_prompt=PromptTemplate.from_template(\n",
    "        \"<document><context>{page_content}</context><metadata><source>{source}</source><page>{page}</page></metadata></document>\"\n",
    "    ),\n",
    ")\n",
    "\n",
    "# 생성된 검색 도구를 도구 리스트에 추가하여 에이전트에서 사용 가능하도록 설정\n",
    "tools = [retriever_tool]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b169f21",
   "metadata": {},
   "source": [
    "## Agent 상태\n",
    "\n",
    "그래프를 정의하겠습니다.\n",
    "\n",
    "각 노드에 전달되는 `state` 객체입니다.\n",
    "\n",
    "상태는 `messages` 목록으로 구성됩니다.\n",
    "\n",
    "그래프의 각 노드는 이 목록에 내용을 추가합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "55abb4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, Sequence, TypedDict\n",
    "from langchain_core.messages import BaseMessage\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "\n",
    "# 에이전트 상태를 정의하는 타입 딕셔너리, 메시지 시퀀스를 관리하고 추가 동작 정의\n",
    "class AgentState(TypedDict):\n",
    "    # add_messages reducer 함수를 사용하여 메시지 시퀀스를 관리\n",
    "    messages: Annotated[Sequence[BaseMessage], add_messages]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f93141",
   "metadata": {},
   "source": [
    "## 노드와 엣지\n",
    "\n",
    "에이전트 기반 RAG 그래프는 다음과 같이 구성될 수 있습니다.\n",
    "\n",
    "* **상태**는 메시지들의 집합입니다\n",
    "* 각 **노드**는 상태를 업데이트(추가)합니다\n",
    "* **조건부 엣지**는 다음에 방문할 노드를 결정합니다"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e17b98e",
   "metadata": {},
   "source": [
    "간단한 채점기(Grader)를 만들어 보겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4d85971f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "from langchain import hub\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from pydantic import BaseModel, Field\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langgraph.prebuilt import tools_condition\n",
    "from langchain_teddynote.models import get_model_name, LLMs\n",
    "\n",
    "# 최신 모델이름 가져오기\n",
    "MODEL_NAME = get_model_name(LLMs.GPT4)\n",
    "\n",
    "\n",
    "# 데이터 모델 정의\n",
    "class grade(BaseModel):\n",
    "    \"\"\"A binary score for relevance checks\"\"\"\n",
    "\n",
    "    binary_score: str = Field(\n",
    "        description=\"Response 'yes' if the document is relevant to the question or 'no' if it is not.\"\n",
    "    )\n",
    "\n",
    "\n",
    "def grade_documents(state) -> Literal[\"generate\", \"rewrite\"]:\n",
    "    # LLM 모델 초기화\n",
    "    model = ChatOpenAI(temperature=0, model=MODEL_NAME, streaming=True)\n",
    "\n",
    "    # 구조화된 출력을 위한 LLM 설정\n",
    "    llm_with_tool = model.with_structured_output(grade)\n",
    "\n",
    "    # 프롬프트 템플릿 정의\n",
    "    prompt = PromptTemplate(\n",
    "        template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
    "        Here is the retrieved document: \\n\\n {context} \\n\\n\n",
    "        Here is the user question: {question} \\n\n",
    "        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
    "        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\",\n",
    "        input_variables=[\"context\", \"question\"],\n",
    "    )\n",
    "\n",
    "    # llm + tool 바인딩 체인 생성\n",
    "    chain = prompt | llm_with_tool\n",
    "\n",
    "    # 현재 상태에서 메시지 추출\n",
    "    messages = state[\"messages\"]\n",
    "\n",
    "    # 가장 마지막 메시지 추출\n",
    "    last_message = messages[-1]\n",
    "\n",
    "    # 원래 질문 추출\n",
    "    question = messages[0].content\n",
    "\n",
    "    # 검색된 문서 추출\n",
    "    retrieved_docs = last_message.content\n",
    "\n",
    "    # 관련성 평가 실행\n",
    "    scored_result = chain.invoke({\"question\": question, \"context\": retrieved_docs})\n",
    "\n",
    "    # 관련성 여부 추출\n",
    "    score = scored_result.binary_score\n",
    "\n",
    "    # 관련성 여부에 따른 결정\n",
    "    if score == \"yes\":\n",
    "        print(\"==== [DECISION: DOCS RELEVANT] ====\")\n",
    "        return \"generate\"\n",
    "\n",
    "    else:\n",
    "        print(\"==== [DECISION: DOCS NOT RELEVANT] ====\")\n",
    "        print(score)\n",
    "        return \"rewrite\"\n",
    "\n",
    "\n",
    "def agent(state):\n",
    "    # 현재 상태에서 메시지 추출\n",
    "    messages = state[\"messages\"]\n",
    "\n",
    "    # LLM 모델 초기화\n",
    "    model = ChatOpenAI(temperature=0, streaming=True, model=MODEL_NAME)\n",
    "\n",
    "    # retriever tool 바인딩\n",
    "    model = model.bind_tools(tools)\n",
    "\n",
    "    # 에이전트 응답 생성\n",
    "    response = model.invoke(messages)\n",
    "\n",
    "    # 기존 리스트에 추가되므로 리스트 형태로 반환\n",
    "    return {\"messages\": [response]}\n",
    "\n",
    "\n",
    "def rewrite(state):\n",
    "    print(\"==== [QUERY REWRITE] ====\")\n",
    "    # 현재 상태에서 메시지 추출\n",
    "    messages = state[\"messages\"]\n",
    "    # 원래 질문 추출\n",
    "    question = messages[0].content\n",
    "\n",
    "    # 질문 개선을 위한 프롬프트 구성\n",
    "    msg = [\n",
    "        HumanMessage(\n",
    "            content=f\"\"\" \\n \n",
    "    Look at the input and try to reason about the underlying semantic intent / meaning. \\n \n",
    "    Here is the initial question:\n",
    "    \\n ------- \\n\n",
    "    {question} \n",
    "    \\n ------- \\n\n",
    "    Formulate an improved question: \"\"\",\n",
    "        )\n",
    "    ]\n",
    "\n",
    "    # LLM 모델로 질문 개선\n",
    "    model = ChatOpenAI(temperature=0, model=MODEL_NAME, streaming=True)\n",
    "    # Query-Transform 체인 실행\n",
    "    response = model.invoke(msg)\n",
    "\n",
    "    # 재작성된 질문 반환\n",
    "    return {\"messages\": [response]}\n",
    "\n",
    "\n",
    "def generate(state):\n",
    "    # 현재 상태에서 메시지 추출\n",
    "    messages = state[\"messages\"]\n",
    "\n",
    "    # 원래 질문 추출\n",
    "    question = messages[0].content\n",
    "\n",
    "    # 가장 마지막 메시지 추출\n",
    "    docs = messages[-1].content\n",
    "\n",
    "    # RAG 프롬프트 템플릿 가져오기\n",
    "    prompt = hub.pull(\"teddynote/rag-prompt\")\n",
    "\n",
    "    # LLM 모델 초기화\n",
    "    llm = ChatOpenAI(model_name=MODEL_NAME, temperature=0, streaming=True)\n",
    "\n",
    "    # RAG 체인 구성\n",
    "    rag_chain = prompt | llm | StrOutputParser()\n",
    "\n",
    "    # 답변 생성 실행\n",
    "    response = rag_chain.invoke({\"context\": docs, \"question\": question})\n",
    "    return {\"messages\": [response]}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f481eb52",
   "metadata": {},
   "source": [
    "## 그래프\n",
    "\n",
    "* `call_model` 에이전트로 시작합니다\n",
    "* 에이전트가 함수를 호출할지 결정합니다\n",
    "* 함수 호출을 결정한 경우, 도구(retriever)를 호출하기 위한 `action`을 실행합니다\n",
    "* 도구의 출력값을 메시지(`state`)에 추가하여 에이전트를 호출합니다"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d42a95a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, StateGraph, START\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "# AgentState 기반 상태 그래프 워크플로우 초기화\n",
    "workflow = StateGraph(AgentState)\n",
    "\n",
    "# 노드 정의\n",
    "workflow.add_node(\"agent\", agent)  # 에이전트 노드\n",
    "retrieve = ToolNode([retriever_tool])\n",
    "workflow.add_node(\"retrieve\", retrieve)  # 검색 노드\n",
    "workflow.add_node(\"rewrite\", rewrite)  # 질문 재작성 노드\n",
    "workflow.add_node(\"generate\", generate)  # 관련 문서 확인 후 응답 생성 노드\n",
    "\n",
    "# 엣지 연결\n",
    "workflow.add_edge(START, \"agent\")\n",
    "\n",
    "# 검색 여부 결정을 위한 조건부 엣지 추가\n",
    "workflow.add_conditional_edges(\n",
    "    \"agent\",\n",
    "    # 에이전트 결정 평가\n",
    "    tools_condition,\n",
    "    {\n",
    "        # 조건 출력을 그래프 노드에 매핑\n",
    "        \"tools\": \"retrieve\",\n",
    "        END: END,\n",
    "    },\n",
    ")\n",
    "\n",
    "# 액션 노드 실행 후 처리될 엣지 정의\n",
    "workflow.add_conditional_edges(\n",
    "    \"retrieve\",\n",
    "    # 문서 품질 평가\n",
    "    grade_documents,\n",
    ")\n",
    "workflow.add_edge(\"generate\", END)\n",
    "workflow.add_edge(\"rewrite\", \"agent\")\n",
    "\n",
    "# 그래프 컴파일\n",
    "graph = workflow.compile(checkpointer=MemorySaver())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd907104",
   "metadata": {},
   "source": [
    "그래프를 시각화합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df716e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "visualize_graph(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9659e932",
   "metadata": {},
   "source": [
    "## 그래프 실행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b1efeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnableConfig\n",
    "from langchain_teddynote.messages import stream_graph, invoke_graph, random_uuid\n",
    "\n",
    "# config 설정(재귀 최대 횟수, thread_id)\n",
    "config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n",
    "\n",
    "# 사용자의 에이전트 메모리 유형에 대한 질문을 포함하는 입력 데이터 구조 정의\n",
    "inputs = {\n",
    "    \"messages\": [\n",
    "        (\"user\", \"삼성전자가 개발한 생성형 AI 의 이름은?\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# 그래프 실행\n",
    "invoke_graph(graph, inputs, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da64d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 그래프 스트리밍 출력\n",
    "stream_graph(graph, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8577bd32",
   "metadata": {},
   "source": [
    "아래는 문서 검색이 **불필요한** 질문의 예시입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd1bc454",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 문서 검색이 불가능한 질문 예시\n",
    "inputs = {\n",
    "    \"messages\": [\n",
    "        (\"user\", \"대한민국의 수도는?\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# 그래프 실행\n",
    "stream_graph(graph, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ac602bd",
   "metadata": {},
   "source": [
    "아래는 임의로 **문서 검색이 불가능한** 질문 예시입니다.\n",
    "\n",
    "따라서, 문서를 지속적으로 검색하는 과정에서 `GraphRecursionError` 가 발생하였습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8ac07b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.errors import GraphRecursionError\n",
    "\n",
    "# 문서 검색이 불가능한 질문 예시\n",
    "inputs = {\n",
    "    \"messages\": [\n",
    "        (\"user\", \"테디노트의 랭체인 튜토리얼에 대해서 알려줘\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "try:\n",
    "    # 그래프 실행\n",
    "    stream_graph(graph, inputs, config, [\"agent\", \"rewrite\", \"generate\"])\n",
    "except GraphRecursionError as recursion_error:\n",
    "    print(f\"GraphRecursionError: {recursion_error}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c093f5c",
   "metadata": {},
   "source": [
    "다음 튜토리얼에서는 이를 해결하는 방법을 다룹니다."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-kr-lwwSZlnu-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
